import logging
import random
import torch
import torch.nn.functional as F
from omegaconf import OmegaConf

from ..losses.distance_weighting import make_mask_distance_weighter
from ..losses.feature_matching import feature_matching_loss, masked_l1_loss
from ..modules.fake_fakes import FakeFakesGenerator
from .base import BaseInpaintingTrainingModule, make_multiscale_noise
from ...utils import add_prefix_to_keys, get_ramp

LOGGER = logging.getLogger(__name__)


def ceil_modulo(x, mod):
    """Round ``x`` up to the nearest multiple of ``mod``.

    Args:
        x: Value to round.
        mod: Step the result must be a multiple of.

    Returns:
        ``x`` itself when it is already a multiple of ``mod``, otherwise the
        next multiple of ``mod`` above ``x``.
    """
    whole_steps, leftover = divmod(x, mod)
    return x if leftover == 0 else (whole_steps + 1) * mod


def make_constant_area_crop_params(img_height, img_width, min_size=128, max_size=512, area=256 * 256, round_to_mod=16):
    """Sample a random crop window whose side lengths are derived from a target area.

    One side (height or width, chosen with equal probability) is drawn
    uniformly from ``[min_size, max_size]``; the other side is computed as
    ``area // first_side``. Both sides are rounded up to a multiple of
    ``round_to_mod`` and then clamped to the effective maximum, so a clamped
    side is not necessarily a multiple of ``round_to_mod``.

    Args:
        img_height: Height of the image to crop from.
        img_width: Width of the image to crop from.
        min_size: Lower bound for the randomly drawn side; capped by the
            shorter image side.
        max_size: Upper bound for either side; capped by the shorter image side.
        area: Target crop area used to derive the second side.
        round_to_mod: Side lengths are rounded up to a multiple of this value
            before clamping.

    Returns:
        Tuple ``(start_y, start_x, out_height, out_width)``.
    """
    shortest_side = min(img_height, img_width)
    lower = min(shortest_side, min_size)
    upper = min(shortest_side, max_size)

    def snap(length):
        # Round up to the modulo grid, but never exceed the allowed maximum.
        return min(upper, ceil_modulo(length, round_to_mod))

    # Keep the order of random draws: orientation, side length, then offsets.
    height_first = random.random() < 0.5
    primary = snap(random.randint(lower, upper))
    secondary = snap(area // primary)

    if height_first:
        out_height, out_width = primary, secondary
    else:
        out_height, out_width = secondary, primary

    start_y = random.randint(0, img_height - out_height)
    start_x = random.randint(0, img_width - out_width)
    return (start_y, start_x, out_height, out_width)


def make_constant_area_crop_batch(batch, **kwargs):
    """Crop ``batch["image"]`` and ``batch["mask"]`` to the same random window.

    The window comes from ``make_constant_area_crop_params``; image height and
    width are read from dimensions 2 and 3 of ``batch["image"]``.

    Args:
        batch: Mapping with ``"image"`` and ``"mask"`` entries that support
            ``[:, :, rows, cols]`` indexing. It is modified in place.
        **kwargs: Forwarded to ``make_constant_area_crop_params``.

    Returns:
        The same ``batch`` object with both entries replaced by their crops.
    """
    image_shape = batch["image"].shape
    top, left, height, width = make_constant_area_crop_params(
        img_height=image_shape[2], img_width=image_shape[3], **kwargs
    )
    rows = slice(top, top + height)
    cols = slice(left, left + width)
    for key in ("image", "mask"):
        batch[key] = batch[key][:, :, rows, cols]
    return batch


class DefaultInpaintingTrainingModule(BaseInpaintingTrainingModule):
    """Inpainting training module with generator and discriminator losses.

    ``forward`` blanks out the masked area of the image, runs ``self.generator``
    and stores the prediction and derived tensors in the batch dict.
    ``generator_loss`` and ``discriminator_loss`` consume a batch that has
    already gone through ``forward``.

    ``self.generator``, ``self.discriminator``, ``self.adversarial_loss``,
    ``self.loss_pl``, ``self.loss_resnet_pl``, ``self.config`` and
    ``self.global_step`` are not defined here; they are expected to come from
    the base class.
    """

    def __init__(
        self,
        *args,
        concat_mask=True,
        rescale_scheduler_kwargs=None,
        image_to_discriminator="predicted_image",
        add_noise_kwargs=None,
        noise_fill_hole=False,
        const_area_crop_kwargs=None,
        distance_weighter_kwargs=None,
        distance_weighted_mask_for_discr=False,
        fake_fakes_proba=0,
        fake_fakes_generator_kwargs=None,
        **kwargs
    ):
        """Store the training options and build the optional helpers.

        Args:
            *args: Forwarded to the base class constructor.
            concat_mask: If true, the mask is concatenated to the generator
                input along dimension 1.
            rescale_scheduler_kwargs: If given, passed to ``get_ramp``; the
                result is called with ``self.global_step`` during training to
                get the size images and masks are resized to.
            image_to_discriminator: Batch key of the image used as the "fake"
                sample in both loss methods.
            add_noise_kwargs: If given, passed to ``make_multiscale_noise``;
                the noise is concatenated to the generator input.
            noise_fill_hole: If true (and noise is enabled), the masked area is
                filled with the leading noise channels.
            const_area_crop_kwargs: If given, passed to
                ``make_constant_area_crop_batch`` during training.
            distance_weighter_kwargs: If given, passed to
                ``make_mask_distance_weighter``; the result produces
                ``batch["mask_for_losses"]`` during training.
            distance_weighted_mask_for_discr: If true, the generator's
                adversarial loss uses ``batch["mask_for_losses"]`` instead of
                ``batch["mask"]``.
            fake_fakes_proba: Probability of generating "fake fakes" on a
                training step; the feature is active only when above ``1e-3``.
            fake_fakes_generator_kwargs: Keyword arguments for
                ``FakeFakesGenerator``.
            **kwargs: Forwarded to the base class constructor.
        """
        super().__init__(*args, **kwargs)
        self.concat_mask = concat_mask

        if rescale_scheduler_kwargs is None:
            self.rescale_size_getter = None
        else:
            self.rescale_size_getter = get_ramp(**rescale_scheduler_kwargs)

        self.image_to_discriminator = image_to_discriminator
        self.add_noise_kwargs = add_noise_kwargs
        self.noise_fill_hole = noise_fill_hole
        self.const_area_crop_kwargs = const_area_crop_kwargs

        if distance_weighter_kwargs is None:
            self.refine_mask_for_losses = None
        else:
            self.refine_mask_for_losses = make_mask_distance_weighter(**distance_weighter_kwargs)

        self.distance_weighted_mask_for_discr = distance_weighted_mask_for_discr

        self.fake_fakes_proba = fake_fakes_proba
        # self.fake_fakes_gen exists only when the feature is enabled.
        if self.fake_fakes_proba > 1e-3:
            generator_options = fake_fakes_generator_kwargs or {}
            self.fake_fakes_gen = FakeFakesGenerator(**generator_options)

    def forward(self, batch):
        """Run the generator on the masked image and fill ``batch`` with the results.

        Args:
            batch: Dict with ``"image"`` and ``"mask"`` tensors. It is modified
                in place. Assumes the mask is 1 where the image must be
                filled — TODO confirm against the data pipeline.

        Returns:
            The batch with ``"predicted_image"``, ``"inpainted"`` and
            ``"mask_for_losses"`` added, plus ``"fake_fakes"``,
            ``"fake_fakes_masks"`` and ``"use_fake_fakes"`` when
            ``fake_fakes_proba`` is above ``1e-3``.
        """
        # Resizing and cropping are applied during training only.
        if self.training:
            if self.rescale_size_getter is not None:
                target_size = self.rescale_size_getter(self.global_step)
                batch["image"] = F.interpolate(
                    batch["image"], size=target_size, mode="bilinear", align_corners=False
                )
                batch["mask"] = F.interpolate(batch["mask"], size=target_size, mode="nearest")

            if self.const_area_crop_kwargs is not None:
                batch = make_constant_area_crop_batch(batch, **self.const_area_crop_kwargs)

        image = batch["image"]
        hole_mask = batch["mask"]

        # Zero out the area the generator has to fill.
        generator_input = image * (1 - hole_mask)

        if self.add_noise_kwargs is not None:
            noise = make_multiscale_noise(generator_input, **self.add_noise_kwargs)
            if self.noise_fill_hole:
                # Only as many noise channels as the image has are used to fill the hole.
                image_channels = generator_input.shape[1]
                generator_input = generator_input + hole_mask * noise[:, :image_channels]
            generator_input = torch.cat([generator_input, noise], dim=1)

        if self.concat_mask:
            generator_input = torch.cat([generator_input, hole_mask], dim=1)

        prediction = self.generator(generator_input)
        batch["predicted_image"] = prediction
        # Prediction inside the mask, original pixels outside it.
        batch["inpainted"] = hole_mask * prediction + (1 - hole_mask) * image

        if self.fake_fakes_proba > 1e-3:
            # The random draw happens only in training mode (short-circuit `and`).
            draw_fake_fakes = self.training and torch.rand(1).item() < self.fake_fakes_proba
            if draw_fake_fakes:
                fake_fakes, fake_fakes_masks = self.fake_fakes_gen(image, hole_mask)
                use_fake_fakes = True
            else:
                fake_fakes = torch.zeros_like(image)
                fake_fakes_masks = torch.zeros_like(hole_mask)
                use_fake_fakes = False
            batch["fake_fakes"] = fake_fakes
            batch["fake_fakes_masks"] = fake_fakes_masks
            batch["use_fake_fakes"] = use_fake_fakes

        # The refined mask is computed in training only; otherwise the plain mask is used.
        if self.refine_mask_for_losses is not None and self.training:
            batch["mask_for_losses"] = self.refine_mask_for_losses(image, prediction, hole_mask)
        else:
            batch["mask_for_losses"] = hole_mask

        return batch

    def generator_loss(self, batch):
        """Compute the generator objective for a batch processed by ``forward``.

        The total is the masked L1 loss plus, depending on the configuration,
        the perceptual loss, the adversarial generator loss, the feature
        matching loss and the ``loss_resnet_pl`` term.

        Args:
            batch: Dict produced by ``forward``.

        Returns:
            Tuple ``(total_loss, metrics)`` where ``metrics`` maps metric names
            to the individual loss values.
        """
        target = batch["image"]
        fake = batch[self.image_to_discriminator]
        hole_mask = batch["mask"]
        loss_mask = batch["mask_for_losses"]
        loss_cfg = self.config.losses

        # L1
        l1_value = masked_l1_loss(
            fake,
            target,
            loss_mask,
            loss_cfg.l1.weight_known,
            loss_cfg.l1.weight_missing,
        )
        metrics = {"gen_l1": l1_value}
        # NOTE: accumulate with `a = a + b`, never `+=`; an in-place add would
        # modify the tensors already stored in `metrics`.
        total_loss = l1_value

        # Perceptual loss (called "vgg-based" in the original comment;
        # self.loss_pl is defined outside this class).
        perceptual_weight = loss_cfg.perceptual.weight
        if perceptual_weight > 0:
            pl_value = self.loss_pl(fake, target, mask=loss_mask).sum() * perceptual_weight
            total_loss = total_loss + pl_value
            metrics["gen_pl"] = pl_value

        # discriminator
        # adversarial_loss calls backward by itself
        if self.distance_weighted_mask_for_discr:
            discr_mask = loss_mask
        else:
            discr_mask = hole_mask
        self.adversarial_loss.pre_generator_step(
            real_batch=target,
            fake_batch=fake,
            generator=self.generator,
            discriminator=self.discriminator,
        )
        # Real sample first, then fake: keep this call order.
        real_pred, real_features = self.discriminator(target)
        fake_pred, fake_features = self.discriminator(fake)
        adv_gen_loss, adv_metrics = self.adversarial_loss.generator_loss(
            real_batch=target,
            fake_batch=fake,
            discr_real_pred=real_pred,
            discr_fake_pred=fake_pred,
            mask=discr_mask,
        )
        total_loss = total_loss + adv_gen_loss
        metrics["gen_adv"] = adv_gen_loss
        metrics.update(add_prefix_to_keys(adv_metrics, "adv_"))

        # feature matching
        fm_weight = loss_cfg.feature_matching.weight
        if fm_weight > 0:
            # "pass_mask" is optional in the config, hence the plain-dict lookup with a default.
            fm_options = OmegaConf.to_container(loss_cfg.feature_matching)
            fm_mask = loss_mask if fm_options.get("pass_mask", False) else None
            fm_value = feature_matching_loss(fake_features, real_features, mask=fm_mask) * fm_weight
            total_loss = total_loss + fm_value
            metrics["gen_fm"] = fm_value

        # Extra perceptual term, used only when self.loss_resnet_pl is set.
        if self.loss_resnet_pl is not None:
            resnet_pl_value = self.loss_resnet_pl(fake, target)
            total_loss = total_loss + resnet_pl_value
            metrics["gen_resnet_pl"] = resnet_pl_value

        return total_loss, metrics

    def discriminator_loss(self, batch):
        """Compute the discriminator objective for a batch processed by ``forward``.

        The fake sample is detached, so this loss does not backpropagate into
        the generator through it. When ``batch["use_fake_fakes"]`` is true, a
        second adversarial term is added for ``batch["fake_fakes"]``.

        Args:
            batch: Dict produced by ``forward``.

        Returns:
            Tuple ``(total_loss, metrics)``.
        """
        # Start from 0 and use out-of-place adds so tensors stored in `metrics`
        # are never modified in place.
        total_loss = 0
        metrics = {}

        fake = batch[self.image_to_discriminator].detach()
        real = batch["image"]
        hole_mask = batch["mask"]

        self.adversarial_loss.pre_discriminator_step(
            real_batch=real,
            fake_batch=fake,
            generator=self.generator,
            discriminator=self.discriminator,
        )
        # Only the predictions are needed here; the features are discarded.
        real_pred, _ = self.discriminator(real)
        fake_pred, _ = self.discriminator(fake)
        adv_discr_loss, adv_metrics = self.adversarial_loss.discriminator_loss(
            real_batch=real,
            fake_batch=fake,
            discr_real_pred=real_pred,
            discr_fake_pred=fake_pred,
            mask=hole_mask,
        )
        total_loss = total_loss + adv_discr_loss
        metrics["discr_adv"] = adv_discr_loss
        metrics.update(add_prefix_to_keys(adv_metrics, "adv_"))

        if batch.get("use_fake_fakes", False):
            fake_fakes = batch["fake_fakes"]
            self.adversarial_loss.pre_discriminator_step(
                real_batch=real,
                fake_batch=fake_fakes,
                generator=self.generator,
                discriminator=self.discriminator,
            )
            fake_fakes_pred, _ = self.discriminator(fake_fakes)
            # The real-sample prediction from above is reused, and the loss is
            # masked with batch["mask"], not batch["fake_fakes_masks"].
            fake_fakes_loss, fake_fakes_metrics = self.adversarial_loss.discriminator_loss(
                real_batch=real,
                fake_batch=fake_fakes,
                discr_real_pred=real_pred,
                discr_fake_pred=fake_fakes_pred,
                mask=hole_mask,
            )
            total_loss = total_loss + fake_fakes_loss
            metrics["discr_adv_fake_fakes"] = fake_fakes_loss
            # NOTE(review): same "adv_" prefix as above, so keys with the same
            # name replace the earlier adversarial metrics.
            metrics.update(add_prefix_to_keys(fake_fakes_metrics, "adv_"))

        return total_loss, metrics
